home *** CD-ROM | disk | FTP | other *** search
/ Cream of the Crop 1 / Cream of the Crop 1.iso / PROGRAM / BPNN132U.ARJ / BPTRAIN.C < prev    next >
Text File  |  1992-04-28  |  6KB  |  272 lines

  1. /*
  2. *-----------------------------------------------------------------------------
  3. *    file:    bptrain.c
  4. *    desc:    back propagation Multi Layer Perceptron (MLP) training
  5. *    by:    patrick ko
  6. *    date:    02 aug 1991
  7. *    revi:    v1.32u 26 apr 1992
  8. *-----------------------------------------------------------------------------
  9. */
  10.  
  11. #include <stdio.h>
  12. #include <stdlib.h>
  13. #ifdef __TURBOC__
  14. #include <mem.h>
  15. #include <alloc.h>
  16. #endif
  17.  
  18. #include "nntype.h"
  19. #include "nncreat.h"
  20. #include "nntrain.h"
  21. #include "nnerror.h"
  22. #include "cparser.h"
  23. #include "bptrainv.h"
  24. #include "timer.h"
  25.  
  26. #define MAXHIDDEN    128
  27.  
  28. static INTEGER    hiddencnt = 0;
  29. static INTEGER    hidden[MAXHIDDEN];
  30. static INTEGER    output;
  31. static INTEGER    input;
  32. static INTEGER    totalhidden;
  33. static INTEGER    totalpatt = 0;
  34. static REAL    trainerr = ERROR_DEFAULT;
  35. static INTEGER    report = 0;
  36. static INTEGER    timer = 0;
  37. static long int    tdump = 0;
  38.  
  39. static VECTOR    **inputvect;
  40. static VECTOR    **targtvect;
  41.  
  42. extern REAL    TOLER;
  43.  
  44. static char    tname[128];
  45. /*
  46. *    dump file name with default
  47. */
  48. static char    dname[128] = "bptrain.dmp";
  49. static char    dinname[128] = "";
  50.  
  51. int    usage( )
  52. {
  53.     printf( "%s %s - by %s\n", PROGNAME, VERSION, AUTHOR );
  54.     printf( "Copyright (c) 1992 All Rights Reserved. %s\n\n", DATE );
  55.     printf( "Description: backprop neural net training with adaptive coefficients\n");
  56.     printf( "Usage: %s @file -i=# -o=# -hh=# {-h=#} -samp=# -ftrain=<fn>\n", PROGNAME);
  57.     printf( "[-fdump=<fn>] [-fdumpin=<fn>] -r=# [-t] [-tdump=#] [-w+=# -w-=#]\n" );
  58.     printf( "[-err=] [-torerr=] [// ...]\n");
  59.     printf( "Example: " );
  60.     printf( "create and train a 2x4x3x1 dimension NN with 10 samples\n");
  61.     printf( "%s -i=2 -o=1 -hh=2 -h=4 -h=3 -err=0.01 ", PROGNAME );
  62.     printf( "-ftrain=input.trn -samp=10\n" );
  63.     printf( "Where:\n" );
  64.     printf( "-i=,-o=     dimension of input/output layer\n" );
  65.     printf( "-hh=        number of hidden layers\n" );
  66.     printf( "-h=         each hidden layer dimension (may be multiple)\n" );
  67.     printf( "-ftrain=    name of train file containing inputs and targets\n" );
  68.     printf( "-fdump=     name of output weights dump file\n" );
  69.     printf( "-fdumpin=   name of input weights dump file (if any)\n");
  70.     printf( "-samp=      number of train input patterns in train file\n" );
  71.     printf( "-r=         report training status interval\n" );
  72.     printf( "-t          time the training (good for non-Unix)\n" );
  73.     printf( "-tdump=     time for periodic dump (specify seconds)\n");
  74.     printf( "-w+=        initial random weight upper bound\n" );
  75.     printf( "-w-=        initial random weight lower bound\n" );
  76.     printf( "-err=       mean square per unit train error ");
  77.     printf( "(def=%f)\n", ERROR_DEFAULT );
  78.     printf( "-torerr=    tolerance error (def=%f)\n", TOLER_DEFAULT);
  79.     exit (0);
  80. }
  81.  
  82. int    parse( )
  83. {
  84.     int    cmd;
  85.     char    rest[128];
  86.     int    resti;
  87.     long    restl;
  88.  
  89.     while ((cmd = cmdget( rest ))!= -1)
  90.         {
  91.         resti = atoi(rest);
  92.         restl = atol(rest);
  93.         switch (cmd)
  94.             {
  95.             case CMD_DIMINPUT:
  96.                 input = resti; break;
  97.             case CMD_DIMOUTPUT:
  98.                 output = resti; break;
  99.             case CMD_DIMHIDDENY:
  100.                 if (input <= 0 || output <= 0)
  101.                     {
  102.                     error( NNIOLAYER );
  103.                     }
  104.                 if (resti > MAXHIDDEN)
  105.                     {
  106.                     error( NN2MANYLAYER );
  107.                     }
  108.                 totalhidden = resti; break;
  109.             case CMD_DIMHIDDEN:
  110.                 if (hiddencnt >= totalhidden)
  111.                     {
  112.                     /*
  113.                     * hidden layers more than specified
  114.                     */
  115.                     break;
  116.                     }
  117.                 hidden[hiddencnt++] = resti;
  118.                 break;
  119.             case CMD_TRAINFILE:
  120.                 strcpy( tname, rest );
  121.                 break;
  122.             case CMD_TOTALPATT:
  123.                 totalpatt = resti;
  124.                 break;
  125.             case CMD_DUMPFILE:
  126.                 strcpy( dname, rest );
  127.                 break;
  128.             case CMD_DUMPIN:
  129.                 strcpy( dinname, rest );
  130.                 break;
  131.             case CMD_TRAINERR:
  132.                 trainerr = atof( rest );
  133.                 break;
  134.             case CMD_REPORT:
  135.                 report = resti;
  136.                 break;
  137.             case CMD_TIMER:
  138.                 timer = 1;
  139.                 break;
  140.             case CMD_TDUMP:
  141.                 tdump = restl;
  142.                 break;
  143.             case CMD_WPOS:
  144.                 UB = atof(rest);
  145.                 break;
  146.             case CMD_WNEG:
  147.                 LB = atof(rest);
  148.                 break;
  149.             case CMD_TOLER:
  150.                 TOLER = atof(rest);
  151.                 break;
  152.             case CMD_COMMENT:
  153.                 break;
  154.             case CMD_NULL:
  155.                 printf( "%s: unknown command [%s]\n", PROGNAME, rest );
  156.                 exit (2);
  157.                 break;
  158.             }
  159.         }
  160.         if (hiddencnt < totalhidden)
  161.             {
  162.             error( NN2MANYHIDDEN );
  163.             }
  164. }
  165.  
  166. int    gettrainvect( tname )
  167. char    *tname;
  168. {
  169.     int    i, j, cnt;
  170.     VECTOR    *tmp;
  171.     FILE    *ft;
  172.  
  173.  
  174.     ft = fopen( tname, "r" );
  175.     if (ft == NULL)
  176.         {
  177.         error( NNTFRERR );
  178.         }
  179.  
  180.     inputvect = malloc( sizeof(VECTOR *) * totalpatt );
  181.     targtvect = malloc( sizeof(VECTOR *) * totalpatt );
  182.  
  183.     if (totalpatt <= 0)
  184.         {
  185.         error( NN2FEWPATT );
  186.         }
  187.     for (i=0; i<totalpatt; i++)
  188.         {
  189.         /*
  190.         *    allocate input patterns
  191.         */
  192.         tmp = v_creat( input );
  193.         for (j=0; j<input; j++)
  194.             {
  195.             cnt = fscanf( ft, "%lf", &tmp->vect[j] );
  196.             if (cnt < 1)
  197.                 {
  198.                 error( NNTFIERR );
  199.                 }
  200.             }
  201.         *(inputvect + i) = tmp;
  202.  
  203.         tmp = v_creat( output );
  204.         for (j=0; j<output; j++)
  205.             {
  206.             cnt = fscanf( ft, "%lf", &tmp->vect[j] );
  207.             if (cnt < 1)
  208.                 {
  209.                 error( NNTFIERR );
  210.                 }
  211.             }
  212.         *(targtvect + i) = tmp;
  213.         }
  214.     fclose( ft );
  215. }
  216.  
  217.  
  218. int    main( argc, argv )
  219. int    argc;
  220. char    ** argv;
  221. {
  222.     NET    *nn;
  223.     FILE    *fdump;
  224.  
  225.     if (argc < 2)
  226.         {
  227.         usage();
  228.         }
  229.     else
  230.         {
  231.         cmdinit( argc, argv );
  232.         parse();
  233.         }
  234.  
  235.     /* create a neural net */
  236.     nn = nn_creat( totalhidden + 1, input, output, hidden );
  237.  
  238.     gettrainvect( tname );
  239.  
  240.     /* read last dump, if any */
  241.     if (*dinname != NULL)
  242.         {
  243.         printf( "%s: opening dump file [%s] ...\n", PROGNAME, dinname);
  244.         if ((fdump = fopen( dinname, "r" )) != NULL)
  245.             {
  246.             nn_load( fdump, nn );
  247.             fclose( fdump );
  248.             }
  249.         }
  250.  
  251.     printf( "%s: start\n", PROGNAME );
  252.  
  253.     if (timer)
  254.         timer_restart();
  255.     /*
  256.     * the default training error, ..., etc can be incorporated into
  257.     * the interface - if you like.
  258.     */
  259.     nnbp_train( nn, inputvect, targtvect, totalpatt, trainerr, ETA_DEFAULT, ALPHA_DEFAULT, report, tdump, dname );
  260.  
  261.     if (timer)
  262.         printf("%s: time elapsed = %ld secs\n", PROGNAME, timer_stop());
  263.  
  264.     printf( "%s: dump neural net to [%s]\n", PROGNAME, dname );
  265.  
  266.     fdump = fopen( dname, "w" );
  267.     nn_dump( fdump, nn );
  268.     fclose(fdump);
  269. }
  270.  
  271.  
  272.